{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "import numpy as np\n",
    "from datasets import load_dataset\n",
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    AutoModelForSequenceClassification,\n",
    ")\n",
    "from transformers import TrainingArguments\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from transformers import Trainer\n",
    "\n",
    "\n",
    "import os\n",
    "os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'\n",
    "os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# AG_News 英文分类数据集\n",
    "# ds = load_dataset(\"fancyzhx/ag_news\")\n",
    "\n",
    "## 中文分类数据集\n",
    "ds = load_dataset(\"lansinuote/ChnSentiCorp\")\n",
    "\n",
    "model_name = \"bert-base-chinese\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    model_name,\n",
    "    trust_remote_code=True,\n",
    ")\n",
    "\n",
    "bert = AutoModelForSequenceClassification.from_pretrained(\n",
    "    model_name,\n",
    "    trust_remote_code=True,\n",
    "    num_labels=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_func(item):\n",
    "    global tokenizer\n",
    "    tokenized_inputs = tokenizer(\n",
    "        item[\"text\"],\n",
    "        max_length=512,\n",
    "        truncation=True,\n",
    "    )\n",
    "    return tokenized_inputs\n",
    "\n",
    "tokenized_datasets = ds.map(\n",
    "    tokenize_func,\n",
    "    batched=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 通过下述命令，查看 trainer.save_model 保存的是否是最好的模型权重。\n",
    "# 通过md5值和sha1判断是否为同一个文件\n",
    "\n",
    "# !find . -type f -name \"*.safetensors\" -exec sha1sum {} \\;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "\n",
    "@dataclass\n",
    "class BertCLS:\n",
    "    def __init__(self, model, train_dataset=None, eval_dataset=None, output_dir=\"output\", epoch=3):\n",
    "        self.model = model\n",
    "        self.train_dataset = train_dataset\n",
    "        self.eval_dataset = eval_dataset\n",
    "        self.args = self.get_args(output_dir, epoch)\n",
    "        from transformers import DataCollatorWithPadding\n",
    "        self.data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
    "        self.trainer = Trainer(\n",
    "            model=self.model,\n",
    "            args=self.args,\n",
    "            train_dataset=self.train_dataset,\n",
    "            eval_dataset=self.eval_dataset,\n",
    "            data_collator=self.data_collator,\n",
    "            # compute_metrics=compute_metrics,\n",
    "            tokenizer=tokenizer,\n",
    "        )\n",
    "        \n",
    "    def get_args(self, output_dir, epoch):\n",
    "        args = TrainingArguments(\n",
    "            output_dir=output_dir,\n",
    "            evaluation_strategy=\"epoch\",\n",
    "            save_strategy=\"epoch\",\n",
    "            save_total_limit=3,\n",
    "            learning_rate=2e-5,\n",
    "            num_train_epochs=epoch,\n",
    "            weight_decay=0.01,\n",
    "            per_device_train_batch_size=32,\n",
    "            per_device_eval_batch_size=16,\n",
    "            # logging_steps=16,\n",
    "            save_safetensors=True,\n",
    "            overwrite_output_dir=True,\n",
    "            load_best_model_at_end=True,\n",
    "        )\n",
    "        return args\n",
    "    \n",
    "    def set_args(self, args):\n",
    "        \"\"\"\n",
    "            从外部重新设置 TrainingArguments，args 更新后，trainer也进行更新\n",
    "        \"\"\"\n",
    "        self.args = args\n",
    "        \n",
    "        self.trainer = Trainer(\n",
    "            model=self.model,\n",
    "            args=self.args,\n",
    "            train_dataset=self.train_dataset,\n",
    "            eval_dataset=self.eval_dataset,\n",
    "            data_collator=self.data_collator,\n",
    "            # compute_metrics=compute_metrics,\n",
    "            tokenizer=tokenizer,\n",
    "        )\n",
    "        \n",
    "    def train(self, over_write=False):\n",
    "        best_model_path = os.path.join(self.args.output_dir, \"best_model\")\n",
    "        \n",
    "        if over_write:\n",
    "            self.trainer.train()\n",
    "            self.trainer.save_model()\n",
    "        elif not os.path.exists(best_model_path):\n",
    "            self.trainer.train()\n",
    "            self.trainer.save_model()\n",
    "        else:\n",
    "            print(f\"预训练权重 {best_model_path} 已存在，且over_write={over_write}。不启动模型训练！\")\n",
    "\n",
    "    def eval(self, eval_dataset):\n",
    "        predictions = self.trainer.predict(eval_dataset)\n",
    "        preds = np.argmax(predictions.predictions, axis=-1)\n",
    "        metric = evaluate.load(\"glue\", \"mrpc\")\n",
    "        return metric.compute(predictions=preds, references=predictions.label_ids)\n",
    "    \n",
    "    def pred(self, pred_dataset):\n",
    "        predictions = self.trainer.predict(pred_dataset)\n",
    "        preds = np.argmax(predictions.predictions, axis=-1)\n",
    "        return pred_dataset.add_column(\"pred\", preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['text', 'label'],\n",
       "    num_rows: 1200\n",
       "})"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenized_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jie/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/training_args.py:1525: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "bert_cls = BertCLS(\n",
    "    model=bert,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"validation\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'accuracy': 0.5075, 'f1': 0.6729385722191478}"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bert_cls.eval(tokenized_datasets[\"test\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='114' max='114' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [114/114 01:06, Epoch 3/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Model Preparation Time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.288048</td>\n",
       "      <td>0.005900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.259681</td>\n",
       "      <td>0.005900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.260843</td>\n",
       "      <td>0.005900</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "bert_cls.train(over_write=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'accuracy': 0.9566666666666667, 'f1': 0.9577922077922078}"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bert_cls.eval(tokenized_datasets[\"test\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载 best_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "不训练模型，加载本地模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jie/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/training_args.py:1525: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "bert_cls = BertCLS(\n",
    "    model=AutoModelForSequenceClassification.from_pretrained(\"output/best_model\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'accuracy': 0.9341666666666667, 'f1': 0.9341117597998332}"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bert_cls.eval(tokenized_datasets[\"test\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
